package com.chatplus.application.controller.api;

import cn.dev33.satoken.annotation.SaIgnore;
import cn.hutool.core.bean.BeanUtil;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.conditions.update.LambdaUpdateWrapper;
import com.chatplus.application.aiprocessor.handler.ChatWebSocketHandler;
import com.chatplus.application.common.exception.BadRequestException;
import com.chatplus.application.common.util.MessageTokenUtil;
import com.chatplus.application.constant.Constants;
import com.chatplus.application.domain.entity.account.UserEntity;
import com.chatplus.application.domain.entity.chat.ChatHistoryEntity;
import com.chatplus.application.domain.entity.chat.ChatItemEntity;
import com.chatplus.application.domain.entity.chat.ChatRoleEntity;
import com.chatplus.application.domain.request.ChatItemUpdateRequest;
import com.chatplus.application.domain.request.ChatTokenRequest;
import com.chatplus.application.domain.response.ChatHistoryDetailResponse;
import com.chatplus.application.domain.response.ChatItemDetailResponse;
import com.chatplus.application.service.account.UserService;
import com.chatplus.application.service.chat.ChatHistoryService;
import com.chatplus.application.service.chat.ChatItemService;
import com.chatplus.application.service.chat.ChatRoleService;
import com.chatplus.application.web.basecontroller.BaseController;
import org.apache.commons.lang3.StringUtils;
import org.springframework.web.bind.annotation.*;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static cn.dev33.satoken.stp.StpUtil.isLogin;

/**
 * AI chat sessions: REST endpoints for listing, reading, renaming and deleting a user's
 * chat sessions and their message history, counting tokens, and stopping a reply.
 *
 * <p>Only {@link #list()} is annotated {@code @SaIgnore} and checks {@code isLogin()} itself;
 * the other endpoints call {@code getUserId()} directly, so they presumably rely on a
 * Sa-Token login check configured elsewhere (TODO confirm). Session and history data is
 * scoped to the current user via {@code getUserId()}.
 */
@RestController
@RequestMapping("/api/chat")
public class ChatItemApiController extends BaseController {

    private final ChatItemService chatItemService;
    private final ChatRoleService chatRoleService;
    private final ChatHistoryService chatHistoryService;
    private final ChatWebSocketHandler chatWebSocketHandler;
    private final UserService userService;

    public ChatItemApiController(ChatItemService chatItemService
            , ChatRoleService chatRoleService
            , ChatWebSocketHandler chatWebSocketHandler
            , ChatHistoryService chatHistoryService
            , UserService userService) {
        this.chatItemService = chatItemService;
        this.chatRoleService = chatRoleService;
        this.chatHistoryService = chatHistoryService;
        this.chatWebSocketHandler = chatWebSocketHandler;
        this.userService = userService;
    }

    /**
     * Counts tokens.
     *
     * <p>If {@code text} is non-empty, returns the token count of that text for
     * {@code model}. Otherwise returns the tokens recorded on the current user's last
     * reply in {@code chatId} (the total for that reply, including context), or 0 when
     * there is no such reply.
     *
     * @param request text/model to count, or the chat id to look up
     * @return the token count
     */
    @PostMapping("/tokens")
    public Long token(@RequestBody ChatTokenRequest request) {
        if (StringUtils.isNotEmpty(request.getText())) {
            return MessageTokenUtil.countMessageTokens(request.getText(), request.getModel());
        }
        ChatHistoryEntity entity = chatHistoryService.getLastReplyHistoryByUserId(getUserId(), request.getChatId());
        if (entity != null) {
            return entity.getTokens();
        }
        return 0L;
    }

    /**
     * Lists the current user's chat sessions.
     *
     * <p>Login is not enforced for this endpoint; anonymous callers get an empty list.
     *
     * @return one response per session, with the role icon filled in when the role exists
     */
    @GetMapping("/list")
    @SaIgnore
    public List<ChatItemDetailResponse> list() {
        if (!isLogin()) {
            return Collections.emptyList();
        }
        List<ChatItemEntity> items = chatItemService.getChatItemByUserId(getUserId());
        // Sessions frequently share a role; memoize lookups instead of one query per item.
        Map<Object, ChatRoleEntity> roleCache = new HashMap<>();
        return items.stream().map(item -> toItemResponse(item, roleCache)).toList();
    }

    /**
     * Returns one history record of the given chat owned by the current user.
     *
     * @param chatId the chat id
     * @return the first matching history record returned by the query (no explicit ordering)
     * @throws BadRequestException if the current user has no history for {@code chatId}
     */
    @GetMapping("/detail")
    public ChatHistoryDetailResponse detail(@RequestParam(value = "chat_id") String chatId) {
        LambdaQueryWrapper<ChatHistoryEntity> queryWrapper = new LambdaQueryWrapper<ChatHistoryEntity>()
                .eq(ChatHistoryEntity::getChatId, chatId)
                // Scope to the caller so one user cannot read another user's chat by id.
                .eq(ChatHistoryEntity::getUserId, getUserId())
                .last("limit 1");
        ChatHistoryEntity entity = chatHistoryService.getOne(queryWrapper);
        if (entity == null) {
            throw new BadRequestException("记录不存在");
        }
        return toHistoryResponse(entity);
    }

    /**
     * Deletes all of the current user's chat sessions and their history.
     *
     * <p>NOTE(review): this mutates state via GET, and the two deletes are not wrapped in a
     * transaction, so a failure in the second leaves orphaned history rows.
     *
     * @return {@link Constants#SUCCESS}
     */
    @GetMapping("/clear")
    public String clear() {
        var userId = getUserId();
        List<ChatItemEntity> list = chatItemService.getChatItemByUserId(userId);
        if (!list.isEmpty()) {
            List<String> chatIds = list.stream().map(ChatItemEntity::getChatId).toList();
            // Delete the sessions.
            chatItemService.removeBatchByIds(list.stream().map(ChatItemEntity::getId).toList());
            // Delete the sessions' history.
            chatHistoryService.remove(new LambdaQueryWrapper<ChatHistoryEntity>()
                    .in(ChatHistoryEntity::getChatId, chatIds)
                    .eq(ChatHistoryEntity::getUserId, userId)
            );
        }
        return Constants.SUCCESS;
    }

    /**
     * Returns the message history of a chat owned by the current user.
     *
     * <p>Prompt records get the prompting user's avatar as icon; reply records get the
     * role's icon.
     *
     * @param chatId the chat id
     * @return the history records (empty if the current user has none for this chat)
     */
    @GetMapping("/history")
    public List<ChatHistoryDetailResponse> history(@RequestParam(value = "chat_id") String chatId) {
        LambdaQueryWrapper<ChatHistoryEntity> queryWrapper = new LambdaQueryWrapper<ChatHistoryEntity>()
                .eq(ChatHistoryEntity::getChatId, chatId)
                // Scope to the caller so one user cannot read another user's chat by id.
                .eq(ChatHistoryEntity::getUserId, getUserId());
        List<ChatHistoryEntity> list = chatHistoryService.list(queryWrapper);
        // Memoize icon lookups: every message would otherwise issue its own query.
        Map<Object, UserEntity> userCache = new HashMap<>();
        Map<Object, ChatRoleEntity> roleCache = new HashMap<>();
        List<ChatHistoryDetailResponse> responseList = new ArrayList<>(list.size());
        for (ChatHistoryEntity item : list) {
            ChatHistoryDetailResponse vo = toHistoryResponse(item);
            if ("prompt".equals(item.getType())) {
                UserEntity userEntity = userCache.computeIfAbsent(item.getUserId(),
                        ignored -> userService.getById(item.getUserId()));
                if (userEntity != null) {
                    vo.setIcon(userEntity.getAvatar());
                }
            } else if ("reply".equals(item.getType())) {
                ChatRoleEntity chatRoleEntity = roleCache.computeIfAbsent(item.getRoleId(),
                        ignored -> chatRoleService.getById(item.getRoleId()));
                if (chatRoleEntity != null) {
                    vo.setIcon(chatRoleEntity.getIcon());
                }
            }
            responseList.add(vo);
        }
        return responseList;
    }

    /**
     * Renames one of the current user's chat sessions.
     *
     * @param request the chat id and the new title
     * @return {@link Constants#SUCCESS}
     * @throws BadRequestException if no session of the current user was updated
     */
    @PostMapping("/update")
    public String update(@RequestBody ChatItemUpdateRequest request) {
        LambdaUpdateWrapper<ChatItemEntity> updateWrapper = new LambdaUpdateWrapper<>();
        updateWrapper.eq(ChatItemEntity::getChatId, request.getChatId());
        updateWrapper.eq(ChatItemEntity::getUserId, getUserId());
        updateWrapper.set(ChatItemEntity::getTitle, request.getTitle());
        boolean flag = chatItemService.update(updateWrapper);
        if (!flag) {
            throw new BadRequestException("更新失败");
        }
        return Constants.SUCCESS;
    }

    /**
     * Deletes one of the current user's chat sessions and its history.
     *
     * <p>NOTE(review): despite its Java name ({@code update}, kept for compatibility) this
     * handles {@code /remove}; it also mutates state via GET and the two deletes are not
     * transactional.
     *
     * @param chatId the chat id to delete
     * @return {@link Constants#SUCCESS}
     * @throws BadRequestException if no session of the current user was deleted
     */
    @GetMapping("/remove")
    public String update(@RequestParam(value = "chat_id") String chatId) {
        var userId = getUserId();
        LambdaQueryWrapper<ChatItemEntity> queryWrapper = new LambdaQueryWrapper<>();
        queryWrapper.eq(ChatItemEntity::getChatId, chatId);
        queryWrapper.eq(ChatItemEntity::getUserId, userId);
        boolean flag = chatItemService.remove(queryWrapper);
        if (!flag) {
            throw new BadRequestException("删除失败");
        }
        // Delete the session's history.
        chatHistoryService.remove(new LambdaQueryWrapper<ChatHistoryEntity>()
                .eq(ChatHistoryEntity::getChatId, chatId)
                .eq(ChatHistoryEntity::getUserId, userId));
        return Constants.SUCCESS;
    }

    /**
     * Asks the websocket handler to stop the given session.
     *
     * <p>NOTE(review): no ownership check is visible here; confirm that
     * {@code ChatWebSocketHandler.stop} only lets a user stop their own session.
     *
     * @param sessionId the websocket session id
     * @return {@link Constants#SUCCESS}
     */
    @GetMapping("/stop")
    public String stop(@RequestParam(value = "session_id") String sessionId) {
        chatWebSocketHandler.stop(sessionId);
        return Constants.SUCCESS;
    }

    /**
     * Maps a chat session to its response, resolving the role icon through {@code roleCache}.
     */
    private ChatItemDetailResponse toItemResponse(ChatItemEntity item, Map<Object, ChatRoleEntity> roleCache) {
        ChatItemDetailResponse response = new ChatItemDetailResponse();
        response.setId(item.getId());
        ChatRoleEntity chatRoleEntity = roleCache.computeIfAbsent(item.getRoleId(),
                ignored -> chatRoleService.getById(item.getRoleId()));
        if (chatRoleEntity != null) {
            response.setIcon(chatRoleEntity.getIcon());
        }
        response.setUserId(item.getUserId());
        response.setRoleId(item.getRoleId());
        response.setTitle(item.getTitle());
        response.setModelId(item.getModelId());
        response.setChatId(item.getChatId());
        response.setCreatedAt(item.getCreatedAt().getEpochSecond());
        response.setUpdatedAt(item.getUpdatedAt().getEpochSecond());
        return response;
    }

    /**
     * Copies a history entity into a response, converting its timestamps to epoch seconds.
     */
    private static ChatHistoryDetailResponse toHistoryResponse(ChatHistoryEntity entity) {
        // Timestamps are excluded from the bean copy because the response uses epoch seconds.
        ChatHistoryDetailResponse response = BeanUtil.copyProperties(entity, ChatHistoryDetailResponse.class, "createdAt", "updatedAt");
        response.setCreatedAt(entity.getCreatedAt().getEpochSecond());
        response.setUpdatedAt(entity.getUpdatedAt().getEpochSecond());
        return response;
    }
}
